Call chain profiled: amg.initialize -> _process_crop -> _process_batch -> _to_mask_data & predict_torch
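For context, kernprof's line-by-line mode requires the functions of interest to be decorated, which is why the `@profile` lines show up in the tables below (they were added to the micro-sam and segment_anything sources for these runs). A minimal sketch of the mechanism, not the actual benchmark.py:

```python
# sketch.py -- minimal kernprof setup, not the actual benchmark.py.
# `kernprof -lv sketch.py` injects a global `profile` decorator and prints
# per-line hit counts and timings for every decorated function on exit.

@profile  # noqa: F821 -- `profile` is injected into builtins by kernprof
def hot_function(n):
    total = 0
    for i in range(n):  # each loop line gets its own Hits/Time row
        total += i * i
    return total

if __name__ == "__main__":
    hot_function(1_000_000)
```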
(test-micro-sam-mps) genevieb@192-168-1-102 development % kernprof -lv benchmark.py -m vit_h -d cpu -e -p
Running benchmarks for vit_h
with device: cpu
Running benchmark_amg ...
| model | device | benchmark | runtime |
|:--------|:---------|:------------|----------:|
| vit_h | cpu | amg | 51.9006 |
Wrote profile results to benchmark.py.lprof
Timer unit: 1e-06 s
Total time: 4.04613 s
File: /Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py
Function: _to_mask_data at line 255
Line # Hits Time Per Hit % Time Line Contents
==============================================================
255 @profile
256 def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None):
257 16 5.0 0.3 0.0 orig_h, orig_w = original_size
258
259 # serialize predictions and store in MaskData
260 16 1214.0 75.9 0.0 data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1))
261 16 5.0 0.3 0.0 if points is not None:
262 16 964.0 60.2 0.0 data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float)
263
264 16 0.0 0.0 0.0 del masks
265
266 # calculate the stability scores
267 32 897476.0 28046.1 22.2 data["stability_score"] = amg_utils.calculate_stability_score(
268 16 145.0 9.1 0.0 data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset
269 )
270
271 # threshold masks and calculate boxes
272 16 153194.0 9574.6 3.8 data["masks"] = data["masks"] > self._predictor.model.mask_threshold
273 16 268.0 16.8 0.0 data["masks"] = data["masks"].type(torch.bool)
274 16 370802.0 23175.1 9.2 data["boxes"] = batched_mask_to_box(data["masks"])
275
276 # compress to RLE
277 16 90.0 5.6 0.0 data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
278 16 2611240.0 163202.5 64.5 data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
279 16 10724.0 670.2 0.3 del data["masks"]
280
281 16 1.0 0.1 0.0 return data
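Even on CPU, `mask_to_rle_pytorch` is the clear hotspot inside `_to_mask_data` (64.5%), well ahead of the stability scores (22.2%) and `batched_mask_to_box` (9.2%). A simplified sketch of change-point run-length encoding, only to illustrate where the cost comes from (the real amg_utils implementation differs in details):

```python
import torch

def rle_encode(mask: torch.Tensor) -> list:
    """Run-length encode one boolean mask via change-point detection."""
    flat = mask.flatten()
    # positions where consecutive pixels differ, i.e. where a new run starts
    change = torch.nonzero(flat[1:] != flat[:-1]).flatten() + 1
    bounds = torch.cat([change.new_zeros(1), change,
                        change.new_full((1,), flat.numel())])
    runs = (bounds[1:] - bounds[:-1]).tolist()
    # counts start with a run of zeros, so prepend an empty one when the
    # mask begins with a foreground pixel
    return [0] + runs if bool(flat[0]) else runs
```

`nonzero()` and the final `.tolist()` both need the tensor's values on the host, so on an accelerator each call is a device-to-host synchronization, and as far as I can tell the library version also iterates over the mask batch in Python.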
Total time: 44.9853 s
File: /Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py
Function: _process_batch at line 366
Line # Hits Time Per Hit % Time Line Contents
==============================================================
366 @profile
367 def _process_batch(self, points, im_size, crop_box, original_size):
368 # run model on this batch
369 16 1258.0 78.6 0.0 transformed_points = self._predictor.transform.apply_coords(points, im_size)
370 16 552.0 34.5 0.0 in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float)
371 16 428.0 26.8 0.0 in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
372 32 40890910.0 1e+06 90.9 masks, iou_preds, _ = self._predictor.predict_torch(
373 16 303.0 18.9 0.0 in_points[:, None, :],
374 16 33.0 2.1 0.0 in_labels[:, None],
375 16 1.0 0.1 0.0 multimask_output=True,
376 16 1.0 0.1 0.0 return_logits=True,
377 )
378 16 4046698.0 252918.6 9.0 data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
379 16 45095.0 2818.4 0.1 del masks
380 16 5.0 0.3 0.0 return data
Total time: 50.0068 s
File: /Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py
Function: _process_crop at line 382
Line # Hits Time Per Hit % Time Line Contents
==============================================================
382 @profile
383 def _process_crop(self, image, crop_box, crop_layer_idx, verbose, precomputed_embeddings):
384 # crop the image and calculate embeddings
385 1 1.0 1.0 0.0 x0, y0, x1, y1 = crop_box
386 1 4.0 4.0 0.0 cropped_im = image[y0:y1, x0:x1, :]
387 1 1.0 1.0 0.0 cropped_im_size = cropped_im.shape[:2]
388
389 1 0.0 0.0 0.0 if not precomputed_embeddings:
390 self._predictor.set_image(cropped_im)
391
392 # get the points for this crop
393 1 8.0 8.0 0.0 points_scale = np.array(cropped_im_size)[None, ::-1]
394 1 187.0 187.0 0.0 points_for_image = self.point_grids[crop_layer_idx] * points_scale
395
396 # generate masks for this crop in batches
397 1 27.0 27.0 0.0 data = amg_utils.MaskData()
398 2 1.0 0.5 0.0 n_batches = len(points_for_image) // self._points_per_batch +\
399 1 1.0 1.0 0.0 int(len(points_for_image) % self._points_per_batch != 0)
400 18 3620.0 201.1 0.0 for (points,) in tqdm(
401 1 3.0 3.0 0.0 amg_utils.batch_iterator(self._points_per_batch, points_for_image),
402 1 0.0 0.0 0.0 disable=not verbose, total=n_batches,
403 1 1.0 1.0 0.0 desc="Predict masks for point grid prompts",
404 ):
405 16 44996183.0 3e+06 90.0 batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
406 16 4998693.0 312418.3 10.0 data.cat(batch_data)
407 16 8108.0 506.8 0.0 del batch_data
408
409 1 0.0 0.0 0.0 if not precomputed_embeddings:
410 self._predictor.reset_image()
411
412 1 0.0 0.0 0.0 return data
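The two-line ceiling division above checks out against the 16 hits recorded for `_process_batch`, assuming SAM's defaults of a 32x32 point grid and 64 points per batch:

```python
import math

n_points, points_per_batch = 32 * 32, 64  # assumed SAM defaults
n_batches = n_points // points_per_batch + int(n_points % points_per_batch != 0)
assert n_batches == math.ceil(n_points / points_per_batch) == 16
```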
Total time: 40.8883 s
File: /Users/genevieb/mambaforge/envs/test-micro-sam-mps/lib/python3.10/site-packages/segment_anything/predictor.py
Function: predict_torch at line 168
Line # Hits Time Per Hit % Time Line Contents
==============================================================
168 @torch.no_grad()
169 @profile
170 def predict_torch(
171 self,
172 point_coords: Optional[torch.Tensor],
173 point_labels: Optional[torch.Tensor],
174 boxes: Optional[torch.Tensor] = None,
175 mask_input: Optional[torch.Tensor] = None,
176 multimask_output: bool = True,
177 return_logits: bool = False,
178 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
179 """
180 Predict masks for the given input prompts, using the currently set image.
181 Input prompts are batched torch tensors and are expected to already be
182 transformed to the input frame using ResizeLongestSide.
183
184 Arguments:
185 point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
186 model. Each point is in (X,Y) in pixels.
187 point_labels (torch.Tensor or None): A BxN array of labels for the
188 point prompts. 1 indicates a foreground point and 0 indicates a
189 background point.
190 boxes (np.ndarray or None): A Bx4 array given a box prompt to the
191 model, in XYXY format.
192 mask_input (np.ndarray): A low resolution mask input to the model, typically
193 coming from a previous prediction iteration. Has form Bx1xHxW, where
194 for SAM, H=W=256. Masks returned by a previous iteration of the
195 predict method do not need further transformation.
196 multimask_output (bool): If true, the model will return three masks.
197 For ambiguous input prompts (such as a single click), this will often
198 produce better masks than a single prediction. If only a single
199 mask is needed, the model's predicted quality score can be used
200 to select the best mask. For non-ambiguous prompts, such as multiple
201 input prompts, multimask_output=False can give better results.
202 return_logits (bool): If true, returns un-thresholded masks logits
203 instead of a binary mask.
204
205 Returns:
206 (torch.Tensor): The output masks in BxCxHxW format, where C is the
207 number of masks, and (H, W) is the original image size.
208 (torch.Tensor): An array of shape BxC containing the model's
209 predictions for the quality of each mask.
210 (torch.Tensor): An array of shape BxCxHxW, where C is the number
211 of masks and H=W=256. These low res logits can be passed to
212 a subsequent iteration as mask input.
213 """
214 16 369.0 23.1 0.0 if not self.is_image_set:
215 raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
216
217 16 3.0 0.2 0.0 if point_coords is not None:
218 16 3.0 0.2 0.0 points = (point_coords, point_labels)
219 else:
220 points = None
221
222 # Embed prompts
223 32 18334.0 572.9 0.0 sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
224 16 0.0 0.0 0.0 points=points,
225 16 0.0 0.0 0.0 boxes=boxes,
226 16 0.0 0.0 0.0 masks=mask_input,
227 )
228
229 # Predict masks
230 32 38480121.0 1e+06 94.1 low_res_masks, iou_predictions = self.model.mask_decoder(
231 16 6.0 0.4 0.0 image_embeddings=self.features,
232 16 41064.0 2566.5 0.1 image_pe=self.model.prompt_encoder.get_dense_pe(),
233 16 0.0 0.0 0.0 sparse_prompt_embeddings=sparse_embeddings,
234 16 2.0 0.1 0.0 dense_prompt_embeddings=dense_embeddings,
235 16 3.0 0.2 0.0 multimask_output=multimask_output,
236 )
237
238 # Upscale the masks to the original image resolution
239 16 2348373.0 146773.3 5.7 masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
240
241 16 7.0 0.4 0.0 if not return_logits:
242 masks = masks > self.model.mask_threshold
243
244 16 8.0 0.5 0.0 return masks, iou_predictions, low_res_masks
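On CPU, `predict_torch` is dominated by the mask decoder itself (94.1%), with the upscaling in `postprocess_masks` a distant second (5.7%). For reference, that post-processing step is essentially two bilinear interpolations around a padding crop; a sketch assuming SAM-style behaviour, not the exact library code:

```python
import torch.nn.functional as F

def upscale_masks(low_res_masks, input_size, original_size, img_size=1024):
    # 256x256 logits -> padded model frame (img_size x img_size)
    masks = F.interpolate(low_res_masks, (img_size, img_size),
                          mode="bilinear", align_corners=False)
    # drop the padding added by ResizeLongestSide
    masks = masks[..., : input_size[0], : input_size[1]]
    # padded frame -> original image resolution
    return F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
```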
(test-micro-sam-mps) genevieb@192-168-1-102 development % kernprof -lv benchmark.py -m vit_h -d mps -e -p
Running benchmarks for vit_h
with device: mps
Running benchmark_amg ...
/Users/genevieb/mambaforge/envs/test-micro-sam-mps/lib/python3.10/site-packages/segment_anything/modeling/mask_decoder.py:126: UserWarning: MPS: no support for int64 repeats mask, casting it to int32 (Triggered internally at /Users/runner/work/_temp/anaconda/conda-bld/pytorch_1682343686130/work/aten/src/ATen/native/mps/operations/Repeat.mm:236.)
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
[W MPSFallback.mm:11] Warning: The operator 'torchvision::nms' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (function operator())
| model | device | benchmark | runtime |
|:--------|:---------|:------------|----------:|
| vit_h | mps | amg | 346.799 |
Wrote profile results to benchmark.py.lprof
Timer unit: 1e-06 s
Total time: 123.62 s
File: /Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py
Function: _to_mask_data at line 255
Line # Hits Time Per Hit % Time Line Contents
==============================================================
255 @profile
256 def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None):
257 16 169.0 10.6 0.0 orig_h, orig_w = original_size
258
259 # serialize predictions and store in MaskData
260 16 88827.0 5551.7 0.1 data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1))
261 16 20.0 1.2 0.0 if points is not None:
262 16 71025.0 4439.1 0.1 data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float)
263
264 16 6.0 0.4 0.0 del masks
265
266 # calculate the stability scores
267 32 858545.0 26829.5 0.7 data["stability_score"] = amg_utils.calculate_stability_score(
268 16 33751.0 2109.4 0.0 data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset
269 )
270
271 # threshold masks and calculate boxes
272 16 15019.0 938.7 0.0 data["masks"] = data["masks"] > self._predictor.model.mask_threshold
273 16 269.0 16.8 0.0 data["masks"] = data["masks"].type(torch.bool)
274 16 2740052.0 171253.2 2.2 data["boxes"] = batched_mask_to_box(data["masks"])
275
276 # compress to RLE
277 16 5700.0 356.2 0.0 data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
278 16 119806776.0 7e+06 96.9 data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
279 16 246.0 15.4 0.0 del data["masks"]
280
281 16 2.0 0.1 0.0 return data
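On MPS the picture flips: `mask_to_rle_pytorch` alone is 96.9% of `_to_mask_data` (119.8 s of 123.6 s), which fits its elementwise, host-synchronizing access pattern. One hedged workaround to try (not necessarily the eventual fix) is to move the boolean masks to the CPU before encoding, so the scan and the Python loop stop round-tripping to the device:

```python
# hypothetical variant of the hot lines in _to_mask_data above
data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"].cpu())
del data["masks"]
```

Whether this wins depends on the transfer cost of the uncropped full-resolution masks; measuring both variants is the only way to know.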
Total time: 336.425 s
File: /Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py
Function: _process_batch at line 366
Line # Hits Time Per Hit % Time Line Contents
==============================================================
366 @profile
367 def _process_batch(self, points, im_size, crop_box, original_size):
368 # run model on this batch
369 16 25639.0 1602.4 0.0 transformed_points = self._predictor.transform.apply_coords(points, im_size)
370 16 71020.0 4438.8 0.0 in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float)
371 16 42514.0 2657.1 0.0 in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
372 32 212613510.0 7e+06 63.2 masks, iou_preds, _ = self._predictor.predict_torch(
373 16 970.0 60.6 0.0 in_points[:, None, :],
374 16 51.0 3.2 0.0 in_labels[:, None],
375 16 2.0 0.1 0.0 multimask_output=True,
376 16 0.0 0.0 0.0 return_logits=True,
377 )
378 16 123671154.0 8e+06 36.8 data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
379 16 37.0 2.3 0.0 del masks
380 16 2.0 0.1 0.0 return data
Total time: 343.198 s
File: /Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py
Function: _process_crop at line 382
Line # Hits Time Per Hit % Time Line Contents
==============================================================
382 @profile
383 def _process_crop(self, image, crop_box, crop_layer_idx, verbose, precomputed_embeddings):
384 # crop the image and calculate embeddings
385 1 1.0 1.0 0.0 x0, y0, x1, y1 = crop_box
386 1 2.0 2.0 0.0 cropped_im = image[y0:y1, x0:x1, :]
387 1 1.0 1.0 0.0 cropped_im_size = cropped_im.shape[:2]
388
389 1 0.0 0.0 0.0 if not precomputed_embeddings:
390 self._predictor.set_image(cropped_im)
391
392 # get the points for this crop
393 1 10.0 10.0 0.0 points_scale = np.array(cropped_im_size)[None, ::-1]
394 1 185.0 185.0 0.0 points_for_image = self.point_grids[crop_layer_idx] * points_scale
395
396 # generate masks for this crop in batches
397 1 38.0 38.0 0.0 data = amg_utils.MaskData()
398 2 2.0 1.0 0.0 n_batches = len(points_for_image) // self._points_per_batch +\
399 1 1.0 1.0 0.0 int(len(points_for_image) % self._points_per_batch != 0)
400 18 9459.0 525.5 0.0 for (points,) in tqdm(
401 1 3.0 3.0 0.0 amg_utils.batch_iterator(self._points_per_batch, points_for_image),
402 1 0.0 0.0 0.0 disable=not verbose, total=n_batches,
403 1 1.0 1.0 0.0 desc="Predict masks for point grid prompts",
404 ):
405 16 336429279.0 2e+07 98.0 batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
406 16 6530072.0 408129.5 1.9 data.cat(batch_data)
407 16 228971.0 14310.7 0.1 del batch_data
408
409 1 0.0 0.0 0.0 if not precomputed_embeddings:
410 self._predictor.reset_image()
411
412 1 0.0 0.0 0.0 return data
Total time: 212.395 s
File: /Users/genevieb/mambaforge/envs/test-micro-sam-mps/lib/python3.10/site-packages/segment_anything/predictor.py
Function: predict_torch at line 168
Line # Hits Time Per Hit % Time Line Contents
==============================================================
168 @torch.no_grad()
169 @profile
170 def predict_torch(
171 self,
172 point_coords: Optional[torch.Tensor],
173 point_labels: Optional[torch.Tensor],
174 boxes: Optional[torch.Tensor] = None,
175 mask_input: Optional[torch.Tensor] = None,
176 multimask_output: bool = True,
177 return_logits: bool = False,
178 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
179 """
180 Predict masks for the given input prompts, using the currently set image.
181 Input prompts are batched torch tensors and are expected to already be
182 transformed to the input frame using ResizeLongestSide.
183
184 Arguments:
185 point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
186 model. Each point is in (X,Y) in pixels.
187 point_labels (torch.Tensor or None): A BxN array of labels for the
188 point prompts. 1 indicates a foreground point and 0 indicates a
189 background point.
190 boxes (np.ndarray or None): A Bx4 array given a box prompt to the
191 model, in XYXY format.
192 mask_input (np.ndarray): A low resolution mask input to the model, typically
193 coming from a previous prediction iteration. Has form Bx1xHxW, where
194 for SAM, H=W=256. Masks returned by a previous iteration of the
195 predict method do not need further transformation.
196 multimask_output (bool): If true, the model will return three masks.
197 For ambiguous input prompts (such as a single click), this will often
198 produce better masks than a single prediction. If only a single
199 mask is needed, the model's predicted quality score can be used
200 to select the best mask. For non-ambiguous prompts, such as multiple
201 input prompts, multimask_output=False can give better results.
202 return_logits (bool): If true, returns un-thresholded masks logits
203 instead of a binary mask.
204
205 Returns:
206 (torch.Tensor): The output masks in BxCxHxW format, where C is the
207 number of masks, and (H, W) is the original image size.
208 (torch.Tensor): An array of shape BxC containing the model's
209 predictions for the quality of each mask.
210 (torch.Tensor): An array of shape BxCxHxW, where C is the number
211 of masks and H=W=256. These low res logits can be passed to
212 a subsequent iteration as mask input.
213 """
214 16 12.0 0.8 0.0 if not self.is_image_set:
215 raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
216
217 16 7.0 0.4 0.0 if point_coords is not None:
218 16 3.0 0.2 0.0 points = (point_coords, point_labels)
219 else:
220 points = None
221
222 # Embed prompts
223 32 471902.0 14746.9 0.2 sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
224 16 3.0 0.2 0.0 points=points,
225 16 2.0 0.1 0.0 boxes=boxes,
226 16 1.0 0.1 0.0 masks=mask_input,
227 )
228
229 # Predict masks
230 32 142891626.0 4e+06 67.3 low_res_masks, iou_predictions = self.model.mask_decoder(
231 16 1.0 0.1 0.0 image_embeddings=self.features,
232 16 209664.0 13104.0 0.1 image_pe=self.model.prompt_encoder.get_dense_pe(),
233 16 3.0 0.2 0.0 sparse_prompt_embeddings=sparse_embeddings,
234 16 2.0 0.1 0.0 dense_prompt_embeddings=dense_embeddings,
235 16 1.0 0.1 0.0 multimask_output=multimask_output,
236 )
237
238 # Upscale the masks to the original image resolution
239 16 68821903.0 4e+06 32.4 masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
240
241 16 92.0 5.8 0.0 if not return_logits:
242 masks = masks > self.model.mask_threshold
243
244 16 54.0 3.4 0.0 return masks, iou_predictions, low_res_masks
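One caveat when reading the MPS tables: MPS kernels execute asynchronously, so line_profiler tends to charge queued GPU work to whichever later line first forces a synchronization, which may inflate lines like the RLE conversion. A sketch for getting honest per-stage wall times on MPS:

```python
import time
import torch

def timed(fn, *args, **kwargs):
    """Time fn with explicit MPS synchronization before and after."""
    if torch.backends.mps.is_available():
        torch.mps.synchronize()  # drain previously queued kernels
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if torch.backends.mps.is_available():
        torch.mps.synchronize()  # wait for this call's kernels to finish
    return result, time.perf_counter() - start
```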
Call chain profiled (second run, with only _process_crop and _process_batch decorated): amg.initialize -> _process_crop -> _process_batch
(test-micro-sam-mps) genevieb@192-168-1-102 development % kernprof -lv benchmark.py -m vit_h -d cpu -e -p
Running benchmarks for vit_h
with device: cpu
Running benchmark_amg ...
| model | device | benchmark | runtime |
|:--------|:---------|:------------|----------:|
| vit_h | cpu | amg | 45.2618 |
Wrote profile results to benchmark.py.lprof
Timer unit: 1e-06 s
Total time: 38.4909 s
File: /Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py
Function: _process_batch at line 365
Line # Hits Time Per Hit % Time Line Contents
==============================================================
365 @profile
366 def _process_batch(self, points, im_size, crop_box, original_size):
367 # run model on this batch
368 16 911.0 56.9 0.0 transformed_points = self._predictor.transform.apply_coords(points, im_size)
369 16 875.0 54.7 0.0 in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float)
370 16 393.0 24.6 0.0 in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
371 32 34451581.0 1e+06 89.5 masks, iou_preds, _ = self._predictor.predict_torch(
372 16 297.0 18.6 0.0 in_points[:, None, :],
373 16 31.0 1.9 0.0 in_labels[:, None],
374 16 2.0 0.1 0.0 multimask_output=True,
375 16 0.0 0.0 0.0 return_logits=True,
376 )
377 16 4001238.0 250077.4 10.4 data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
378 16 35558.0 2222.4 0.1 del masks
379 16 4.0 0.2 0.0 return data
Total time: 43.4143 s
File: /Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py
Function: _process_crop at line 381
Line # Hits Time Per Hit % Time Line Contents
==============================================================
381 @profile
382 def _process_crop(self, image, crop_box, crop_layer_idx, verbose, precomputed_embeddings):
383 # crop the image and calculate embeddings
384 1 0.0 0.0 0.0 x0, y0, x1, y1 = crop_box
385 1 1.0 1.0 0.0 cropped_im = image[y0:y1, x0:x1, :]
386 1 0.0 0.0 0.0 cropped_im_size = cropped_im.shape[:2]
387
388 1 0.0 0.0 0.0 if not precomputed_embeddings:
389 self._predictor.set_image(cropped_im)
390
391 # get the points for this crop
392 1 7.0 7.0 0.0 points_scale = np.array(cropped_im_size)[None, ::-1]
393 1 16.0 16.0 0.0 points_for_image = self.point_grids[crop_layer_idx] * points_scale
394
395 # generate masks for this crop in batches
396 1 37.0 37.0 0.0 data = amg_utils.MaskData()
397 2 1.0 0.5 0.0 n_batches = len(points_for_image) // self._points_per_batch +\
398 1 0.0 0.0 0.0 int(len(points_for_image) % self._points_per_batch != 0)
399 18 1977.0 109.8 0.0 for (points,) in tqdm(
400 1 4.0 4.0 0.0 amg_utils.batch_iterator(self._points_per_batch, points_for_image),
401 1 0.0 0.0 0.0 disable=not verbose, total=n_batches,
402 1 1.0 1.0 0.0 desc="Predict masks for point grid prompts",
403 ):
404 16 38501125.0 2e+06 88.7 batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
405 16 4903074.0 306442.1 11.3 data.cat(batch_data)
406 16 8065.0 504.1 0.0 del batch_data
407
408 1 0.0 0.0 0.0 if not precomputed_embeddings:
409 self._predictor.reset_image()
410
411 1 0.0 0.0 0.0 return data
(test-micro-sam-mps) genevieb@192-168-1-102 development % kernprof -lv benchmark.py -m vit_h -d mps -e -p
Running benchmarks for vit_h
with device: mps
Running benchmark_amg ...
/Users/genevieb/mambaforge/envs/test-micro-sam-mps/lib/python3.10/site-packages/segment_anything/modeling/mask_decoder.py:126: UserWarning: MPS: no support for int64 repeats mask, casting it to int32 (Triggered internally at /Users/runner/work/_temp/anaconda/conda-bld/pytorch_1682343686130/work/aten/src/ATen/native/mps/operations/Repeat.mm:236.)
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
[W MPSFallback.mm:11] Warning: The operator 'torchvision::nms' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (function operator())
| model | device | benchmark | runtime |
|:--------|:---------|:------------|----------:|
| vit_h | mps | amg | 249.157 |
Wrote profile results to benchmark.py.lprof
Timer unit: 1e-06 s
Total time: 240.035 s
File: /Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py
Function: _process_batch at line 365
Line # Hits Time Per Hit % Time Line Contents
==============================================================
365 @profile
366 def _process_batch(self, points, im_size, crop_box, original_size):
367 # run model on this batch
368 16 20809.0 1300.6 0.0 transformed_points = self._predictor.transform.apply_coords(points, im_size)
369 16 62775.0 3923.4 0.0 in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float)
370 16 26964.0 1685.2 0.0 in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
371 32 126742884.0 4e+06 52.8 masks, iou_preds, _ = self._predictor.predict_torch(
372 16 1792.0 112.0 0.0 in_points[:, None, :],
373 16 47.0 2.9 0.0 in_labels[:, None],
374 16 4.0 0.2 0.0 multimask_output=True,
375 16 1.0 0.1 0.0 return_logits=True,
376 )
377 16 113179665.0 7e+06 47.2 data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
378 16 42.0 2.6 0.0 del masks
379 16 2.0 0.1 0.0 return data
Total time: 245.755 s
File: /Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py
Function: _process_crop at line 381
Line # Hits Time Per Hit % Time Line Contents
==============================================================
381 @profile
382 def _process_crop(self, image, crop_box, crop_layer_idx, verbose, precomputed_embeddings):
383 # crop the image and calculate embeddings
384 1 1.0 1.0 0.0 x0, y0, x1, y1 = crop_box
385 1 1.0 1.0 0.0 cropped_im = image[y0:y1, x0:x1, :]
386 1 0.0 0.0 0.0 cropped_im_size = cropped_im.shape[:2]
387
388 1 1.0 1.0 0.0 if not precomputed_embeddings:
389 self._predictor.set_image(cropped_im)
390
391 # get the points for this crop
392 1 7.0 7.0 0.0 points_scale = np.array(cropped_im_size)[None, ::-1]
393 1 176.0 176.0 0.0 points_for_image = self.point_grids[crop_layer_idx] * points_scale
394
395 # generate masks for this crop in batches
396 1 710.0 710.0 0.0 data = amg_utils.MaskData()
397 2 0.0 0.0 0.0 n_batches = len(points_for_image) // self._points_per_batch +\
398 1 0.0 0.0 0.0 int(len(points_for_image) % self._points_per_batch != 0)
399 18 21123.0 1173.5 0.0 for (points,) in tqdm(
400 1 134.0 134.0 0.0 amg_utils.batch_iterator(self._points_per_batch, points_for_image),
401 1 0.0 0.0 0.0 disable=not verbose, total=n_batches,
402 1 1.0 1.0 0.0 desc="Predict masks for point grid prompts",
403 ):
404 16 240049242.0 2e+07 97.7 batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
405 16 5672560.0 354535.0 2.3 data.cat(batch_data)
406 16 11417.0 713.6 0.0 del batch_data
407
408 1 0.0 0.0 0.0 if not precomputed_embeddings:
409 self._predictor.reset_image()
410
411 1 0.0 0.0 0.0 return data
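For reference, a quick roll-up of the headline runtimes from the four runs above; on this machine MPS comes out several times slower than CPU for AMG, with the RLE conversion and the mask decoder the main suspects:

```python
runtimes = {  # seconds, copied from the benchmark tables above
    ("cpu", "more decorators"): 51.9006, ("mps", "more decorators"): 346.799,
    ("cpu", "fewer decorators"): 45.2618, ("mps", "fewer decorators"): 249.157,
}
for run in ("more decorators", "fewer decorators"):
    ratio = runtimes[("mps", run)] / runtimes[("cpu", run)]
    print(f"{run}: mps is {ratio:.1f}x slower than cpu")
```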
Benchmarks from Constantin's development/benchmarks.py file: